import torch
import torch.nn as nn

class cnnsmall(nn.Module):
    """Small VGG-style CNN classifier.

    Feature extractor: two stages, each made of two 3x3 Conv-BatchNorm-ReLU
    layers followed by a 2x2 max-pool. The first stage has ``16 * width``
    channels and the second has ``32 * width``. The resulting map is
    adaptively average-pooled to 4x4, flattened, and fed to a two-layer MLP
    head.

    Args:
        in_channel: number of channels in the input images.
        num_classes: number of outputs of the final Linear layer.
        width: channel multiplier applied to every conv and hidden linear width.
        **kwargs: accepted and ignored (presumably so a shared config dict
            can be splatted into any model constructor -- TODO confirm).
    """

    def __init__(self, in_channel=3, num_classes=10, width=1, **kwargs):
        super().__init__()
        narrow, wide = 16 * width, 32 * width

        # Build one flat layer list so the Sequential indices (and therefore the
        # state_dict keys) are features.0 ... features.13, as before.
        layers = []
        for stage_in, stage_out in ((in_channel, narrow), (narrow, wide)):
            # First conv of the stage changes the channel count; the second keeps it.
            for conv_in in (stage_in, stage_out):
                layers += [
                    # bias=False: the BatchNorm right after it has its own learnable shift.
                    nn.Conv2d(conv_in, stage_out, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(stage_out),
                    nn.ReLU(),
                ]
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*layers)

        # Fixed 4x4 output, so the flattened size does not depend on input H, W.
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        # Per the original author, the extra hidden Linear layer helps a lot on MNIST.
        flat_dim = wide * 4 * 4
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, wide),
            nn.ReLU(),
            nn.Linear(wide, num_classes),
        )

    def forward(self, x, hidden=False):
        """Run the network on a batch.

        Args:
            x: 4-D input batch ``(N, in_channel, H, W)`` (BatchNorm2d requires 4-D).
            hidden: when True, also return the flattened pooled features.

        Returns:
            The classifier output ``(N, num_classes)``. When ``hidden`` is True,
            returns the tuple ``(output, features)``, where ``features`` is the
            ``(N, 32 * width * 16)`` tensor that was fed to the classifier.
        """
        pooled = self.avgpool(self.features(x))
        embedding = torch.flatten(pooled, 1)
        logits = self.classifier(embedding)
        return (logits, embedding) if hidden else logits
        
        
class cnnsupersmall(nn.Module):
    """Very small CNN classifier: two Conv-BatchNorm-ReLU-MaxPool stages.

    Each stage is a single 3x3 convolution (``16 * width`` channels, then
    ``32 * width``), BatchNorm, ReLU and a 2x2 max-pool. The result is
    adaptively average-pooled to 4x4, flattened, and passed through a
    two-layer MLP head.

    Args:
        in_channel: number of channels in the input images.
        num_classes: number of outputs of the final Linear layer.
        width: channel multiplier applied to every conv and hidden linear width.
        **kwargs: accepted and ignored (presumably so a shared config dict
            can be splatted into any model constructor -- TODO confirm).
    """

    def __init__(self, in_channel=3, num_classes=10, width=1, **kwargs):
        super().__init__()
        narrow, wide = 16 * width, 32 * width

        # Flat layer list keeps Sequential indices features.0 ... features.7,
        # so state_dict keys are unchanged.
        layers = []
        for c_in, c_out in ((in_channel, narrow), (narrow, wide)):
            layers.extend((
                # bias=False: the following BatchNorm has its own learnable shift.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ))
        self.features = nn.Sequential(*layers)

        # Fixed 4x4 output, so the flattened size does not depend on input H, W.
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        # Per the original author, the extra hidden Linear layer helps a lot on MNIST.
        hidden_in = wide * 4 * 4
        self.classifier = nn.Sequential(
            nn.Linear(hidden_in, wide),
            nn.ReLU(),
            nn.Linear(wide, num_classes),
        )

    def forward(self, x, hidden=False):
        """Run the network on a batch.

        Args:
            x: 4-D input batch ``(N, in_channel, H, W)`` (BatchNorm2d requires 4-D).
            hidden: when True, also return the flattened pooled features.

        Returns:
            The classifier output ``(N, num_classes)``. When ``hidden`` is True,
            returns the tuple ``(output, features)``, where ``features`` is the
            ``(N, 32 * width * 16)`` tensor that was fed to the classifier.
        """
        flat = torch.flatten(self.avgpool(self.features(x)), 1)
        scores = self.classifier(flat)
        if not hidden:
            return scores
        return scores, flat